Done by -
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from model import Model
import utils
import wandb
import argparse
from attrdict import AttrDict
import matplotlib.pyplot as plt
import os
print("Module versions:")
print('\n'.join(f'{"> " + m.__name__}: {m.__version__}' for m in globals().values() if getattr(m, '__version__', None)))
plt.style.use('seaborn')
Module versions:
> torch: 1.10.0
> wandb: 0.12.11
> argparse: 1.1
The general structure of our data folder is described below. The sequences.txt file has been modified and put up on GitHub in the datasets folder: the original listed a corrupted video file, person01_boxing_d4_uncomp.avi, which we removed from sequences.txt for simplicity. Everything else remains the same.
To download KTH action dataset, download the dataset from the website https://www.csc.kth.se/cvap/actions/ or run the script download_kth.sh inside the KTH folder.
On the first run, set the download flag to True in the load_dataset() function in utils.py for the KTH dataset. Once the .pt files have been generated, set the download flag back to False for subsequent runs.
|── scripts
|── data
    |── KTH
        |── boxing
        |── handclapping
        |── handwaving
        |── jogging
        |── running
        |── walking
        |── sequences.txt
        |── data
            |── all .pt files, processed video files
    |── MNIST
        |── raw
        |── processed
        |── moving_test.pt
        |── moving_train.pt
Path to the dataset folder, which contains two separate folders for the Moving-MNIST and KTH datasets:
dataset_path="data"
mnist_path = os.path.join(dataset_path, "MNIST")
kth_path = os.path.join(dataset_path, "KTH")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
cuda
train_set_mnist, test_set_mnist = utils.load_information(mnist_path,dataset='smmnist')
print("---> Total length of training set for Moving-MNIST dataset", len(train_set_mnist))
print("---> Total length of testing set for Moving-MNIST dataset", len(test_set_mnist))
utils.visualise_sample(train_set_mnist,False,device,False)
---> Total length of training set for Moving-MNIST dataset 60000
---> Total length of testing set for Moving-MNIST dataset 10000
train_set_kth, test_set_kth = utils.load_information(kth_path,dataset='kth')
print("---> Total length of training set for KTH dataset", len(train_set_kth))
print("---> Total length of testing set for KTH dataset", len(test_set_kth))
utils.visualise_sample(train_set_kth,False,device,False)
---> Total length of training set for KTH dataset 23619
---> Total length of testing set for KTH dataset 12599
train_loader_mmnist = DataLoader(
dataset=train_set_mnist,
batch_size=32,
shuffle=True)
print("---> Length of train dataloader",len(train_loader_mmnist))
test_loader_mmnist = DataLoader(
dataset=test_set_mnist,
batch_size=32,
shuffle=False)
print("---> Length of test dataloader",len(test_loader_mmnist))
print("Data loader for Moving-MNIST dataset ready !")
---> Length of train dataloader 1875
---> Length of test dataloader 313
Data loader for Moving-MNIST dataset ready !
train_loader_kth = DataLoader(
dataset=train_set_kth,
batch_size=32,
shuffle=True)
print("---> Length of train dataloader",len(train_loader_kth))
test_loader_kth = DataLoader(
dataset=test_set_kth,
batch_size=32,
shuffle=True)
print("---> Length of test dataloader",len(test_loader_kth))
print("Data loader for KTH action dataset ready !")
---> Length of train dataloader 739
---> Length of test dataloader 394
Data loader for KTH action dataset ready !
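The dataloader lengths above follow directly from the dataset sizes: with the default drop_last=False, a DataLoader has ⌈N / batch_size⌉ batches. A quick sanity check of the four numbers printed above:

```python
import math

# Number of batches is ceil(dataset_size / batch_size) when drop_last=False
batch_size = 32
for name, n in [("Moving-MNIST train", 60000), ("Moving-MNIST test", 10000),
                ("KTH train", 23619), ("KTH test", 12599)]:
    print(f"{name}: {math.ceil(n / batch_size)} batches")
```

This reproduces the lengths 1875, 313, 739 and 394 reported by the dataloaders.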
Instructions to run training
To train and evaluate the model use the commands listed below:
python scripts/main.py -c dataset_config.yaml --lr_warmup True --add_ssim True --criterion loss_function -s scheduler
-c corresponds to the config file; the two config files, kt.yaml and mnist.yaml, are present in the configs folder.
--lr_warmup - set to True if LR warmup is to be applied to the scheduler in use, else False.
--add_ssim - set to True if SSIM is to be combined with MSE or MAE as the training loss, else False.
--criterion - the loss criterion used for training; its two values are 'mae' and 'mse'.
-s - the type of scheduler used; its values are 'exponential' and 'plateau', corresponding to the two schedulers ExponentialLR and ReduceLROnPlateau.
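The exact argument definitions live in scripts/main.py; a minimal argparse sketch consistent with the flags described above (names and defaults here are assumptions for illustration):

```python
import argparse

# Minimal sketch of the CLI described above; the real definitions are in
# scripts/main.py and may differ in defaults and help text.
parser = argparse.ArgumentParser()
parser.add_argument("-c", "--config", required=True,
                    help="config file: configs/kt.yaml or configs/mnist.yaml")
parser.add_argument("--lr_warmup", type=lambda s: s == "True", default=False,
                    help="apply LR warmup to the chosen scheduler")
parser.add_argument("--add_ssim", type=lambda s: s == "True", default=False,
                    help="combine SSIM with the MSE/MAE criterion")
parser.add_argument("--criterion", choices=["mse", "mae"], default="mse",
                    help="loss criterion used for training")
parser.add_argument("-s", "--scheduler", choices=["exponential", "plateau"],
                    default="plateau", help="LR scheduler type")

args = parser.parse_args(["-c", "configs/mnist.yaml", "--lr_warmup", "True",
                          "--criterion", "mse", "-s", "plateau"])
print(args.criterion, args.scheduler, args.lr_warmup)
```

Note the `type=lambda s: s == "True"` trick: the commands above pass booleans as the literal strings `True`/`False`, which a plain `type=bool` would not parse correctly (any non-empty string is truthy).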
This trains the frame prediction model and saves the model after every 5th epoch in the model directory.
It also generates folders in the results directory every log-frequency steps. Each folder contains the ground-truth and predicted frames for the test dataset. These outputs, along with the loss, are written to Weights and Biases as well.
Evaluation:
Once training is completed and the models are saved, the evaluate_model.py file can be used to calculate the following metrics for the model: MSE, MAE, PSNR, SSIM and LPIPS.
This evaluation can be run using the following command:
python scripts/evaluate_model.py -d moving_mnist -mp model_path -s tensor_saving_path
-d corresponds to the dataloader used; the values are 'moving_mnist' and 'kth' for the Moving MNIST and KTH Action datasets.
-mp corresponds to the path along with the model name and type (example: models/mnist/model_50.pth) where the model is stored.
-s corresponds to the path where the tensors for the metrics are stored (example: results_eval/mnist)
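Of the metrics listed, PSNR is derived directly from MSE. A minimal reference implementation, assuming pixel intensities normalised to [0, 1] (the function name and exact form used in evaluate_model.py may differ):

```python
import math

def psnr(mse, max_val=1.0):
    """Peak signal-to-noise ratio in dB, computed from a mean squared
    error over images with pixel intensities in [0, max_val]."""
    if mse == 0:
        return float("inf")  # identical images
    return 20 * math.log10(max_val) - 10 * math.log10(mse)

print(round(psnr(0.001), 2))  # -> 30.0
```

Higher PSNR is better; unlike SSIM and LPIPS, it is a purely pixel-wise measure and does not capture perceptual similarity.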
Experiments Conducted
All experiments are performed for 50 epochs, while the best model is trained again for 100 epochs. The various experiments performed with the Moving MNIST dataset and their corresponding Weights and Biases links are provided below:
Moving Mnist trained with MSE and ReduceLRonPlateau scheduler : Wandb Link, Model Link.
Moving Mnist trained with MSE + SSIM loss, ReduceLRonPlateau scheduler and LR warmup : Wandb Link ,Model Link.
Moving Mnist trained with MSE + SSIM loss, Exponential LR scheduler and LR warmup : Wandb Link,Model Link.
Moving Mnist trained with MAE + SSIM loss, Exponential LR scheduler and LR warmup : Wandb Link ,Model Link.
Moving Mnist trained with MAE and ReduceLRonPlateau scheduler : Wandb Link, Model Link.
Moving Mnist trained with MSE + SSIM loss, ReduceLRonPlateau scheduler and LR warmup (Trained without context frame addition) : Wandb Link, Model Link.
Moving Mnist trained with MSE + SSIM loss, ReduceLRonPlateau scheduler and LR warmup (With skip connections in encoder/decoder - vgg blocks) : Wandb Link, Model Link.
Moving Mnist trained with MAE + SSIM loss, ReduceLRonPlateau scheduler and LR warmup : Wandb Link, Model Link.
Moving Mnist trained with MSE + SSIM loss, ReduceLRonPlateau scheduler and LR warmup, trained for 100 epochs : Wandb Link, Model Link.
sample=next(iter(test_loader_mmnist))
model = Model()
print("Model Loaded")
mnist_mse_plateau, _, epoch = utils.loading_model(model, "C:\\Users\\aysha\\Documents\\Models-MNIST\\mnist_mse+plateau.pth")
mnist_mse_plateau= mnist_mse_plateau.to(device)
Model Loaded
Visualising results on the test dataset. The first line depicts the target sequence and the second depicts the frames our model predicted.
utils.visualise_sample(sample, mnist_mse_plateau, device, True)
model = Model()
print("Model Loaded")
mnist_mse_ssim_plateau_warmup, _, epoch = utils.loading_model(model, "C:\\Users\\aysha\\Documents\\Models-MNIST\\mnist_mse+ssim+plateau+warmup.pth")
mnist_mse_ssim_plateau_warmup= mnist_mse_ssim_plateau_warmup.to(device)
Model Loaded
utils.visualise_sample(sample, mnist_mse_ssim_plateau_warmup, device, True)
model = Model()
print("Model Loaded")
mnist_mse_ssim_explr_warmup, _, epoch = utils.loading_model(model, "C:\\Users\\aysha\\Documents\\Models-MNIST\\mnist_mse+ssim+explr+warmup.pth")
mnist_mse_ssim_explr_warmup= mnist_mse_ssim_explr_warmup.to(device)
Model Loaded
utils.visualise_sample(sample, mnist_mse_ssim_explr_warmup, device, True)
model = Model()
print("Model Loaded")
mnist_mae_plateau, _, epoch = utils.loading_model(model, "C:\\Users\\aysha\\Documents\\Models-MNIST\\mnist_mae+plateau.pth")
mnist_mae_plateau= mnist_mae_plateau.to(device)
Model Loaded
utils.visualise_sample(sample, mnist_mae_plateau, device, True)
model = Model()
print("Model Loaded")
mnist_mae_ssim_explr_warmup, _, epoch = utils.loading_model(model, "C:\\Users\\aysha\\Documents\\Models-MNIST\\mnist_mae+ssim+explr+warmup.pth")
mnist_mae_ssim_explr_warmup= mnist_mae_ssim_explr_warmup.to(device)
Model Loaded
utils.visualise_sample(sample, mnist_mae_ssim_explr_warmup, device, True)
model = Model()
print("Model Loaded")
mnist_mae_ssim_plateau_warmup, _, epoch = utils.loading_model(model, "C:\\Users\\aysha\\Documents\\Models-MNIST\\mnist_mae+ssim+plateau+warmup.pth")
mnist_mae_ssim_plateau_warmup= mnist_mae_ssim_plateau_warmup.to(device)
Model Loaded
utils.visualise_sample(sample, mnist_mae_ssim_plateau_warmup, device, True)
The various experiments performed with the KTH Action Dataset and their corresponding weights and biases links are provided below:
KTH Action Dataset trained with MSE loss and ReduceLRonPlateau scheduler : Wandb Link, Model Link.
KTH Action Dataset trained with MSE + SSIM loss, ReduceLRonPlateau scheduler and LR warmup : Wandb Link, Model Link.
KTH Action Dataset trained with MSE + SSIM loss, Exponential LR scheduler and LR warmup : Wandb Link, Model Link.
KTH Action Dataset trained with MAE + SSIM loss, Exponential LR scheduler and LR warmup : Wandb Link, Model Link.
KTH Action Dataset trained with MAE and ReduceLRonPlateau scheduler : Wandb Link, Model Link.
KTH Action Dataset trained with MAE+SSIM , ReduceLRonPlateau scheduler and LR warmup : Wandb Link, Model Link.
test_sample_kth=next(iter(test_loader_kth))
model = Model()
print("Model Loaded")
KTH_mse_plateau, _, epoch = utils.loading_model(model, "C:\\Users\\aysha\\Documents\\Models-KTH\\KTH_mse_plateau.pth")
KTH_mse_plateau= KTH_mse_plateau.to(device)
Model Loaded
Visualising results on the test dataset. The first line depicts the target sequence and the second depicts the frames our model predicted.
utils.visualise_sample(test_sample_kth, KTH_mse_plateau, device, True)
model = Model()
print("Model Loaded")
KTH_mse_ssim_plateau_warmup, _, epoch = utils.loading_model(model, "C:\\Users\\aysha\\Documents\\Models-KTH\\KTH_mse_ssim_plateau_warmup.pth")
KTH_mse_ssim_plateau_warmup= KTH_mse_ssim_plateau_warmup.to(device)
Model Loaded
utils.visualise_sample(test_sample_kth, KTH_mse_ssim_plateau_warmup, device, True)
model = Model()
print("Model Loaded")
KTH_mse_ssim_explr_warmup, _, epoch = utils.loading_model(model, "C:\\Users\\aysha\\Documents\\Models-KTH\\KTH_mse_ssim_explr_warmup.pth")
KTH_mse_ssim_explr_warmup= KTH_mse_ssim_explr_warmup.to(device)
Model Loaded
utils.visualise_sample(test_sample_kth, KTH_mse_ssim_explr_warmup, device, True)
model = Model()
print("Model Loaded")
KTH_mae_plateau, _, epoch = utils.loading_model(model, "C:\\Users\\aysha\\Documents\\Models-KTH\\KTH_mae_plateau.pth")
KTH_mae_plateau= KTH_mae_plateau.to(device)
Model Loaded
utils.visualise_sample(test_sample_kth, KTH_mae_plateau, device, True)
model = Model()
print("Model Loaded")
KTH_mae_ssim_plateau_warmup, _, epoch = utils.loading_model(model, "C:\\Users\\aysha\\Documents\\Models-KTH\\KTH_mae_ssim_plateau_warmup.pth")
KTH_mae_ssim_plateau_warmup= KTH_mae_ssim_plateau_warmup.to(device)
Model Loaded
utils.visualise_sample(test_sample_kth, KTH_mae_ssim_plateau_warmup, device, True)
model = Model()
print("Model Loaded")
KTH_mae_ssim_explr_warmup, _, epoch = utils.loading_model(model, "C:\\Users\\aysha\\Documents\\Models-KTH\\KTH_mae_ssim_explr_warmup.pth")
KTH_mae_ssim_explr_warmup= KTH_mae_ssim_explr_warmup.to(device)
Model Loaded
utils.visualise_sample(test_sample_kth, KTH_mae_ssim_explr_warmup, device, True)
model = Model()
print("Model Loaded")
mnist_mse_ssim_plateau_warmup_100, _, epoch = utils.loading_model(model, "C:\\Users\\aysha\\Downloads\\model_100_mnist+mse+Ssim+onplateau+warmup.pth")
mnist_mse_ssim_plateau_warmup_100= mnist_mse_ssim_plateau_warmup_100.to(device)
Model Loaded
for i in range(2):
    sample = next(iter(test_loader_mmnist))
    utils.visualise_sample(sample, mnist_mse_ssim_plateau_warmup_100, device, True)
model = Model()
print("Model Loaded")
KTH_mse_ssim_plateau_warmup_best, _, epoch = utils.loading_model(model, "C:\\Users\\aysha\\Documents\\Models-KTH\\KTH_mse_ssim_plateau_warmup.pth")
KTH_mse_ssim_plateau_warmup_best= KTH_mse_ssim_plateau_warmup_best.to(device)
Model Loaded
for i in range(2):
    test_sample_kth = next(iter(test_loader_kth))
    utils.visualise_sample(test_sample_kth, KTH_mse_ssim_plateau_warmup_best, device, True)
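A caveat about the sampling loops above: calling next(iter(loader)) inside a loop rebuilds the iterator on every call, so with shuffle=False it returns the same first batch repeatedly. Reusing a single iterator draws consecutive, distinct batches. A minimal sketch (using a plain list as a stand-in for a DataLoader):

```python
from itertools import islice

# next(iter(loader)) inside a loop rebuilds the iterator each time, so with
# shuffle=False it yields the same first batch on every call. Reusing one
# iterator (here via islice) draws consecutive, distinct batches instead.
loader = [[0, 1], [2, 3], [4, 5]]  # stand-in for a DataLoader
for batch in islice(iter(loader), 2):
    print(batch)
```

With shuffle=True (as in the KTH loaders) the repeated next(iter(...)) pattern happens to work because each fresh iterator reshuffles, but the islice form is correct in both cases.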
Moving Mnist: The best results are obtained when the model is trained with a combined loss, i.e. MSE with SSIM loss, together with the ReduceLROnPlateau scheduler and LR warmup. The scheduler decreases the learning rate by a factor of 0.1 if the monitored loss plateaus for 10 epochs (the patience level), which helps improve results when the model reaches a plateau stage.
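The plateau behaviour described here can be sketched in plain Python with the same settings (factor=0.1, patience=10); the actual training uses torch.optim.lr_scheduler.ReduceLROnPlateau, whose internals are more elaborate than this toy version:

```python
class PlateauSketch:
    """Toy re-implementation of ReduceLROnPlateau's core rule: if the
    monitored loss fails to improve for more than `patience` epochs,
    multiply the learning rate by `factor`."""

    def __init__(self, lr, factor=0.1, patience=10):
        self.lr, self.factor, self.patience = lr, factor, patience
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, loss):
        if loss < self.best:           # improvement: reset the counter
            self.best, self.bad_epochs = loss, 0
        else:                          # plateau: count stalled epochs
            self.bad_epochs += 1
            if self.bad_epochs > self.patience:
                self.lr *= self.factor
                self.bad_epochs = 0
        return self.lr

sched = PlateauSketch(lr=1e-3)
for epoch in range(12):
    sched.step(1.0)  # loss stuck on a plateau for 12 epochs
print(sched.lr)      # reduced by factor 0.1 after patience is exceeded
```

As long as the loss keeps improving the learning rate is untouched; only a sustained plateau triggers the reduction.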
Our best model achieves a comparatively low SSIM when trained with the combined MSE and SSIM loss. Frames predicted with MAE as the loss criterion sometimes disregard the second digit, which leads to perfectly black backgrounds, as opposed to those generated using MSE loss, which show grayish noise in the image. We suspect this is the reason the SSIM metric performs well with MAE loss on the Moving-MNIST dataset.
KTH dataset: The best-performing model is trained using MSE and SSIM as a combined loss function along with the ReduceLROnPlateau scheduler. This model achieves the highest SSIM measure of 0.77 and a high PSNR value too, although not the highest. Its LPIPS value is also the second best (0.239). Qualitatively, this model gives us the best results.